import unittest

import torch
from torchsummary import summary

from timm.models import attention_fusion


class AttentionFusionTestCase(unittest.TestCase):
    """Smoke test for the ``attention_fusion`` model factory."""

    def test_attention_fusion(self):
        """Build the pretrained model on the GPU and run torchsummary on it.

        The test is skipped, rather than failed, when CUDA is unavailable:
        a missing GPU is a property of the environment, not a defect in the
        model, so it should not be reported as a test failure.
        """
        if not torch.cuda.is_available():
            self.skipTest('CUDA is not available')

        model = attention_fusion(pretrained=True).cuda()

        # The same configured input size is passed twice; torchsummary
        # treats a list of shapes as one input tensor per entry.
        single_input_size = model.default_cfg['input_size']
        summary(model, input_size=[single_input_size, single_input_size])


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
